{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#all_slow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *\n",
    "from fastai.vision.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp vision.gan\n",
    "#default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GAN\n",
    "\n",
    "> Basic support for [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GAN stands for [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf), which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to tell real images apart from the ones the generator produces. The generator returns images; the critic returns a single number (usually a probability, 0. for fake images and 1. for real ones).\n",
    "\n",
    "We train them against each other in the sense that at each step (more or less), we:\n",
    "1. Freeze the generator and train the critic for one step by:\n",
    "  - getting one batch of true images (let's call that `real`)\n",
    "  - generating one batch of fake images (let's call that `fake`)\n",
    "  - having the critic evaluate each batch and computing a loss from that; the important part is that the loss rewards the detection of real images and penalizes the fake ones\n",
    "  - updating the weights of the critic with the gradients of this loss\n",
    "  \n",
    "  \n",
    "2. Freeze the critic and train the generator for one step by:\n",
    "  - generating one batch of fake images\n",
    "  - evaluating the critic on it\n",
    "  - computing a loss that rewards the critic thinking those are real images\n",
    "  - updating the weights of the generator with the gradients of this loss"
   ]
  },
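  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The alternation described above can be sketched in plain Python, with no networks involved. This is only a minimal sketch of the scheduling idea, not how fastai implements it: the function `alternating_schedule` and its parameters `n_steps`, `n_crit` and `n_gen` are hypothetical names chosen for illustration, and starting with the critic is an assumption.\n",
    "\n",
    "```python\n",
    "def alternating_schedule(n_steps, n_crit=1, n_gen=1):\n",
    "    # Hypothetical helper: for each step, say which model would be trained.\n",
    "    # Assumption: start with the critic, do n_crit critic steps, then n_gen generator steps.\n",
    "    schedule, gen_mode, count = [], False, 0\n",
    "    for _ in range(n_steps):\n",
    "        schedule.append('generator' if gen_mode else 'critic')\n",
    "        count += 1\n",
    "        # Switch to the other model once the current one has had its share of steps\n",
    "        if count == (n_gen if gen_mode else n_crit):\n",
    "            gen_mode, count = not gen_mode, 0\n",
    "    return schedule\n",
    "\n",
    "print(alternating_schedule(6, n_crit=2, n_gen=1))\n",
    "```\n",
    "\n",
    "Here the critic gets two steps for each generator step. The default switcher used by `GANLearner` further down, `FixedGANSwitcher(n_crit=5, n_gen=1)`, follows the same kind of fixed schedule."
   ]
  },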
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: The fastai library provides support for training GANs through the `GANTrainer`, but doesn't include more than basic models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wrapping the modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GANModule(Module):\n",
    "    \"Wrapper around a `generator` and a `critic` to create a GAN.\"\n",
    "    def __init__(self, generator=None, critic=None, gen_mode=False):\n",
    "        if generator is not None: self.generator=generator\n",
    "        if critic    is not None: self.critic   =critic\n",
    "        store_attr('gen_mode')\n",
    "\n",
    "    def forward(self, *args):\n",
    "        return self.generator(*args) if self.gen_mode else self.critic(*args)\n",
    "\n",
    "    def switch(self, gen_mode=None):\n",
    "        \"Put the module in generator mode if `gen_mode`, in critic mode otherwise.\"\n",
    "        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is just a shell to contain the two models. When called, it will delegate the input to either the `generator` or the `critic`, depending on the value of `gen_mode`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"GANModule.switch\" class=\"doc_header\"><code>GANModule.switch</code><a href=\"__main__.py#L12\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>GANModule.switch</code>(**`gen_mode`**=*`None`*)\n",
       "\n",
       "Put the module in generator mode if `gen_mode`, in critic mode otherwise."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(GANModule.switch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default (leaving `gen_mode` as `None`), this will put the module in the other mode (critic mode if it was in generator mode and vice versa)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(ConvLayer.__init__)\n",
    "def basic_critic(in_size, n_channels, n_features=64, n_extra_layers=0, norm_type=NormType.Batch, **kwargs):\n",
    "    \"A basic critic for images `n_channels` x `in_size` x `in_size`.\"\n",
    "    layers = [ConvLayer(n_channels, n_features, 4, 2, 1, norm_type=None, **kwargs)]\n",
    "    cur_size, cur_ftrs = in_size//2, n_features\n",
    "    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, norm_type=norm_type, **kwargs) for _ in range(n_extra_layers)]\n",
    "    while cur_size > 4:\n",
    "        layers.append(ConvLayer(cur_ftrs, cur_ftrs*2, 4, 2, 1, norm_type=norm_type, **kwargs))\n",
    "        cur_ftrs *= 2 ; cur_size //= 2\n",
    "    init = kwargs.get('init', nn.init.kaiming_normal_)\n",
    "    layers += [init_default(nn.Conv2d(cur_ftrs, 1, 4, padding=0), init), Flatten()]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AddChannels(Module):\n",
    "    \"Add `n_dim` channels at the end of the input.\"\n",
    "    def __init__(self, n_dim): self.n_dim=n_dim\n",
    "    def forward(self, x): return x.view(*(list(x.shape)+[1]*self.n_dim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(ConvLayer.__init__)\n",
    "def basic_generator(out_size, n_channels, in_sz=100, n_features=64, n_extra_layers=0, **kwargs):\n",
    "    \"A basic generator from `in_sz` to images `n_channels` x `out_size` x `out_size`.\"\n",
    "    cur_size, cur_ftrs = 4, n_features//2\n",
    "    while cur_size < out_size:  cur_size *= 2; cur_ftrs *= 2\n",
    "    layers = [AddChannels(2), ConvLayer(in_sz, cur_ftrs, 4, 1, transpose=True, **kwargs)]\n",
    "    cur_size = 4\n",
    "    while cur_size < out_size // 2:\n",
    "        layers.append(ConvLayer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **kwargs))\n",
    "        cur_ftrs //= 2; cur_size *= 2\n",
    "    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **kwargs) for _ in range(n_extra_layers)]\n",
    "    layers += [nn.ConvTranspose2d(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "critic = basic_critic(64, 3)\n",
    "generator = basic_generator(64, 3)\n",
    "tst = GANModule(critic=critic, generator=generator)\n",
    "real = torch.randn(2, 3, 64, 64)\n",
    "real_p = tst(real)\n",
    "test_eq(real_p.shape, [2,1])\n",
    "\n",
    "tst.switch() #tst is now in generator mode\n",
    "noise = torch.randn(2, 100)\n",
    "fake = tst(noise)\n",
    "test_eq(fake.shape, real.shape)\n",
    "\n",
    "tst.switch() #tst is back in critic mode\n",
    "fake_p = tst(fake)\n",
    "test_eq(fake_p.shape, [2,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_conv_args = dict(act_cls = partial(nn.LeakyReLU, negative_slope=0.2), norm_type=NormType.Spectral)\n",
    "\n",
    "def _conv(ni, nf, ks=3, stride=1, self_attention=False, **kwargs):\n",
    "    if self_attention: kwargs['xtra'] = SelfAttention(nf)\n",
    "    return ConvLayer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(ConvLayer)\n",
    "def DenseResBlock(nf, norm_type=NormType.Batch, **kwargs):\n",
    "    \"Resnet block of `nf` features. `kwargs` are passed to `ConvLayer`.\"\n",
    "    return SequentialEx(ConvLayer(nf, nf, norm_type=norm_type, **kwargs),\n",
    "                        ConvLayer(nf, nf, norm_type=norm_type, **kwargs),\n",
    "                        MergeLayer(dense=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def gan_critic(n_channels=3, nf=128, n_blocks=3, p=0.15):\n",
    "    \"Critic to train a `GAN`.\"\n",
    "    layers = [\n",
    "        _conv(n_channels, nf, ks=4, stride=2),\n",
    "        nn.Dropout2d(p/2),\n",
    "        DenseResBlock(nf, **_conv_args)]\n",
    "    nf *= 2 # after dense block\n",
    "    for i in range(n_blocks):\n",
    "        layers += [\n",
    "            nn.Dropout2d(p),\n",
    "            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]\n",
    "        nf *= 2\n",
    "    layers += [\n",
    "        ConvLayer(nf, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),\n",
    "        Flatten()]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GANLoss(GANModule):\n",
    "    \"Wrapper around `crit_loss_func` and `gen_loss_func`\"\n",
    "    def __init__(self, gen_loss_func, crit_loss_func, gan_model):\n",
    "        super().__init__()\n",
    "        store_attr('gen_loss_func,crit_loss_func,gan_model')\n",
    "\n",
    "    def generator(self, output, target):\n",
    "        \"Evaluate the `output` with the critic then use `self.gen_loss_func`\"\n",
    "        fake_pred = self.gan_model.critic(output)\n",
    "        self.gen_loss = self.gen_loss_func(fake_pred, output, target)\n",
    "        return self.gen_loss\n",
    "\n",
    "    def critic(self, real_pred, input):\n",
    "        \"Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`.\"\n",
    "        fake = self.gan_model.generator(input).requires_grad_(False)\n",
    "        fake_pred = self.gan_model.critic(fake)\n",
    "        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)\n",
    "        return self.crit_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In generator mode, this loss function expects the `output` of the generator and some `target` (a batch of real images). It evaluates whether the generator successfully fooled the critic using `gen_loss_func`. This loss function has the following signature\n",
    "``` \n",
    "def gen_loss_func(fake_pred, output, target):\n",
    "```\n",
    "to be able to combine the output of the critic on `output` (which is the first argument, `fake_pred`) with `output` and `target` (if you want to mix the GAN loss with other losses, for instance).\n",
    "\n",
    "In critic mode, this loss function expects the `real_pred` given by the critic and some `input` (the noise fed to the generator). It evaluates the critic using `crit_loss_func`. This loss function has the following signature\n",
    "``` \n",
    "def crit_loss_func(real_pred, fake_pred):\n",
    "```\n",
    "where `real_pred` is the output of the critic on a batch of real images and `fake_pred` is the output of the critic on a batch of fake images generated from the noise by the generator."
   ]
  },
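  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of the two signatures above, here is a minimal plain-Python sketch of a Wasserstein-style pair of losses. It uses lists of floats in place of tensors so that it runs without PyTorch. The names `toy_gen_loss` and `toy_crit_loss` are hypothetical, and real loss functions passed to `GANLoss` would operate on tensors instead.\n",
    "\n",
    "```python\n",
    "from statistics import mean\n",
    "\n",
    "def toy_gen_loss(fake_pred, output, target):\n",
    "    # Hypothetical generator loss: mean critic score on the fakes.\n",
    "    # `output` and `target` are accepted to match the signature but unused here.\n",
    "    return mean(fake_pred)\n",
    "\n",
    "def toy_crit_loss(real_pred, fake_pred):\n",
    "    # Hypothetical critic loss: mean score on real images minus mean score on fakes\n",
    "    return mean(real_pred) - mean(fake_pred)\n",
    "\n",
    "real_pred = [0.5, 1.5]   # stand-in for critic scores on real images\n",
    "fake_pred = [2.0, 4.0]   # stand-in for critic scores on generated images\n",
    "print(toy_crit_loss(real_pred, fake_pred))\n",
    "print(toy_gen_loss(fake_pred, None, None))\n",
    "```\n",
    "\n",
    "A generator loss that also used `output` and `target` (for example a pixel-wise term) would fit the same three-argument signature."
   ]
  },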
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AdaptiveLoss(Module):\n",
    "    \"Expand the `target` to match the `output` size before applying `crit`.\"\n",
    "    def __init__(self, crit): self.crit = crit\n",
    "    def forward(self, output, target):\n",
    "        return self.crit(output, target[:,None].expand_as(output).float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def accuracy_thresh_expand(y_pred, y_true, thresh=0.5, sigmoid=True):\n",
    "    \"Compute accuracy after expanding `y_true` to the size of `y_pred`.\"\n",
    "    if sigmoid: y_pred = y_pred.sigmoid()\n",
    "    return ((y_pred>thresh).byte()==y_true[:,None].expand_as(y_pred).byte()).float().mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callbacks for GAN training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def set_freeze_model(m, rg):\n",
    "    for p in m.parameters(): p.requires_grad_(rg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GANTrainer(Callback):\n",
    "    \"Handles GAN Training.\"\n",
    "    run_after = TrainEvalCallback\n",
    "    def __init__(self, switch_eval=False, clip=None, beta=0.98, gen_first=False, show_img=True):\n",
    "        store_attr('switch_eval,clip,gen_first,show_img')\n",
    "        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)\n",
    "\n",
    "    def _set_trainable(self):\n",
    "        train_model = self.generator if     self.gen_mode else self.critic\n",
    "        loss_model  = self.generator if not self.gen_mode else self.critic\n",
    "        set_freeze_model(train_model, True)\n",
    "        set_freeze_model(loss_model, False)\n",
    "        if self.switch_eval:\n",
    "            train_model.train()\n",
    "            loss_model.eval()\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Initialize smootheners.\"\n",
    "        self.generator,self.critic = self.model.generator,self.model.critic\n",
    "        self.gen_mode = self.gen_first\n",
    "        self.switch(self.gen_mode)\n",
    "        self.crit_losses,self.gen_losses = [],[]\n",
    "        self.gen_loss.reset() ; self.crit_loss.reset()\n",
    "        #self.recorder.no_val=True\n",
    "        #self.recorder.add_metric_names(['gen_loss', 'disc_loss'])\n",
    "        #self.imgs,self.titles = [],[]\n",
    "\n",
    "    def before_validate(self):\n",
    "        \"Switch in generator mode for showing results.\"\n",
    "        self.switch(gen_mode=True)\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Clamp the weights with `self.clip` if it's not None, set the correct input/target.\"\n",
    "        if self.training and self.clip is not None:\n",
    "            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)\n",
    "        if not self.gen_mode:\n",
    "            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Record `last_loss` in the proper list.\"\n",
    "        if not self.training: return\n",
    "        if self.gen_mode:\n",
    "            self.gen_loss.accumulate(self.learn)\n",
    "            self.gen_losses.append(self.gen_loss.value)\n",
    "            self.last_gen = self.learn.to_detach(self.pred)\n",
    "        else:\n",
    "            self.crit_loss.accumulate(self.learn)\n",
    "            self.crit_losses.append(self.crit_loss.value)\n",
    "\n",
    "    def before_epoch(self):\n",
    "        \"Put the critic or the generator back to eval if necessary.\"\n",
    "        self.switch(self.gen_mode)\n",
    "\n",
    "    #def after_epoch(self):\n",
    "    #    \"Show a sample image.\"\n",
    "    #    if not hasattr(self, 'last_gen') or not self.show_img: return\n",
    "    #    data = self.learn.data\n",
    "    #    img = self.last_gen[0]\n",
    "    #    norm = getattr(data,'norm',False)\n",
    "    #    if norm and norm.keywords.get('do_y',False): img = data.denorm(img)\n",
    "    #    img = data.train_ds.y.reconstruct(img)\n",
    "    #    self.imgs.append(img)\n",
    "    #    self.titles.append(f'Epoch {epoch}')\n",
    "    #    pbar.show_imgs(self.imgs, self.titles)\n",
    "    #    return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])\n",
    "\n",
    "    def switch(self, gen_mode=None):\n",
    "        \"Switch the model and loss function, if `gen_mode` is provided, in the desired mode.\"\n",
    "        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode\n",
    "        self._set_trainable()\n",
    "        self.model.switch(gen_mode)\n",
    "        self.loss_func.switch(gen_mode)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Warning: The `GANTrainer` is useless on its own; you need to complete it with one of the following switchers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class FixedGANSwitcher(Callback):\n",
    "    \"Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator.\"\n",
    "    run_after = GANTrainer\n",
    "    def __init__(self, n_crit=1, n_gen=1): store_attr('n_crit,n_gen')\n",
    "    def before_train(self): self.n_c,self.n_g = 0,0\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Switch the model if necessary.\"\n",
    "        if not self.training: return\n",
    "        if self.learn.gan_trainer.gen_mode:\n",
    "            self.n_g += 1\n",
    "            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g\n",
    "        else:\n",
    "            self.n_c += 1\n",
    "            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c\n",
    "        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)\n",
    "        if target == n_out:\n",
    "            self.learn.gan_trainer.switch()\n",
    "            self.n_c,self.n_g = 0,0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AdaptiveGANSwitcher(Callback):\n",
    "    \"Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`critic_thresh`.\"\n",
    "    run_after = GANTrainer\n",
    "    def __init__(self, gen_thresh=None, critic_thresh=None):\n",
    "        store_attr('gen_thresh,critic_thresh')\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Switch the model if necessary.\"\n",
    "        if not self.training: return\n",
    "        if self.gan_trainer.gen_mode:\n",
    "            if self.gen_thresh is None or self.loss < self.gen_thresh: self.gan_trainer.switch()\n",
    "        else:\n",
    "            if self.critic_thresh is None or self.loss < self.critic_thresh: self.gan_trainer.switch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class GANDiscriminativeLR(Callback):\n",
    "    \"`Callback` that handles multiplying the learning rate by `mult_lr` for the critic.\"\n",
    "    run_after = GANTrainer\n",
    "    def __init__(self, mult_lr=5.): self.mult_lr = mult_lr\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Multiply the current lr if necessary.\"\n",
    "        if not self.learn.gan_trainer.gen_mode and self.training:\n",
    "            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Put the LR back to its value if necessary.\"\n",
    "        # Mirror the `before_batch` condition so the LR is only divided when it was multiplied\n",
    "        if not self.learn.gan_trainer.gen_mode and self.training:\n",
    "            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GAN data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class InvisibleTensor(TensorBase):\n",
    "    def show(self, ctx=None, **kwargs): return ctx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def generate_noise(fn, size=100): return cast(torch.randn(size), InvisibleTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_batch(x:InvisibleTensor, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)\n",
    "    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x:InvisibleTensor, y:TensorImage, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize)\n",
    "    ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 128\n",
    "size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dblock = DataBlock(blocks = (TransformBlock, ImageBlock),\n",
    "                   get_x = generate_noise,\n",
    "                   get_items = get_image_files,\n",
    "                   splitter = IndexSplitter([]),\n",
    "                   item_tfms=Resize(size, method=ResizeMethod.Crop), \n",
    "                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.LSUN_BEDROOMS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dblock.dataloaders(path, path=path, bs=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls.show_batch(max_n=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GAN Learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def gan_loss_from_func(loss_gen, loss_crit, weights_gen=None):\n",
    "    \"Define loss functions for a GAN from `loss_gen` and `loss_crit`.\"\n",
    "    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):\n",
    "        ones = fake_pred.new_ones(fake_pred.shape[0])\n",
    "        weights_gen = ifnone(weights_gen, (1.,1.))\n",
    "        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)\n",
    "\n",
    "    def _loss_C(real_pred, fake_pred):\n",
    "        ones  = real_pred.new_ones (real_pred.shape[0])\n",
    "        zeros = fake_pred.new_zeros(fake_pred.shape[0])\n",
    "        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2\n",
    "\n",
    "    return _loss_G, _loss_C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _tk_mean(fake_pred, output, target): return fake_pred.mean()\n",
    "def _tk_diff(real_pred, fake_pred): return real_pred.mean() - fake_pred.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates()\n",
    "class GANLearner(Learner):\n",
    "    \"A `Learner` suitable for GANs.\"\n",
    "    def __init__(self, dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False,\n",
    "                 switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None, **kwargs):\n",
    "        gan = GANModule(generator, critic)\n",
    "        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)\n",
    "        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)\n",
    "        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)\n",
    "        cbs = L(cbs) + L(trainer, switcher)\n",
    "        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))\n",
    "        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_learners(cls, gen_learn, crit_learn, switcher=None, weights_gen=None, **kwargs):\n",
    "        \"Create a GAN from `gen_learn` and `crit_learn`.\"\n",
    "        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)\n",
    "        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def wgan(cls, dls, generator, critic, switcher=None, clip=0.01, switch_eval=False, **kwargs):\n",
    "        \"Create a WGAN from `dls`, `generator` and `critic`.\"\n",
    "        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)\n",
    "\n",
    "GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)\n",
    "GANLearner.wgan = delegates(to=GANLearner.__init__)(GANLearner.wgan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.callback.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = basic_generator(64, n_channels=3, n_extra_layers=1)\n",
    "critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.train_metrics=True\n",
    "learn.recorder.valid_metrics=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(1, 2e-4, wd=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.show_results(max_n=9, ds_idx=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
